Add NVFP4 per-token quantization recipe#3045
Conversation
Rewrites the grouped multi-tensor cast as a K1 fused amax + K2 fused cast
pair and ships pytest correctness + sweep benches against the per-tensor
RHT+SR production baseline.
* common/cast/.../quantize_nvfp4_per_token_group.cu: K1+K2 fused
grouped kernel, reusing the single-tensor 4-stage TMA pipeline.
* common/gemm/nvfp4_per_token_post_scale.cu: row-wise post-scale
kernel for the cuBLASLT NVFP4 dequantize step (maybe updated due
to 2d quant of W).
* pytorch/csrc/extensions/nvfp4_per_token.cpp + pybind.cpp: new C++
grouped bulk binding and per-token GEMM entry; thin pybind layer.
* pytorch/custom_recipes/{gemm_nvfp4_per_token,
quantization_nvfp4_per_token_group}.py: Python wrappers.
* tests/pytorch/nvfp4/test_nvfp4_per_token{,_group}.py: byte-equal
cast tests + bf16-close GEMM tests.
* tests/pytorch/nvfp4/bench_nvfp4_per_token{,_group}.py: 6x3 sweep
over M in {1024..32768} x K in {2048,4096,8192}, eager + CUDA
Graphs columns, ratio against per-tensor RHT+SR baseline.
Co-authored-by: Zhongbo Zhu <zhongboz@nvidia.com>
Signed-off-by: Cael Ling <caell@nvidia.com>
6f17fe4 to
928ab1c
Compare
for more information, see https://pre-commit.ci
…uped) Wire `with_rht` / `random_sign_mask_t` through the per-token K1 (amax) and K2 (encode) kernels for both single-tensor and grouped paths. with_rht=False is byte-equal to the pre-RHT code path; when true, applies a 16-pt RHT on the columnwise direction in both K1 and K2 (rowwise stays raw) with outer amax + inner SF self-consistent. Implementation: per-thread fp32 FHT on CUDA cores, branchless fp32 sign-bit XOR for the +/-1 sign diagonal, 0.25 normalization folded into block_amax / block_scale (bit-exact). Tests cover K1, K2, composite + grouped vs a PyTorch fp32 reference and byte-equality regressions. Benches gain a --rht flag (2-way default, 3-way under --rht). Perf vs prod NVFP4Quantizer(rht+sr), Graph mode, 18 shapes M up to 32K: * single tensor : 0.49x-0.77x (no RHT), 0.59x-0.88x (+RHT) * grouped (N=8) : 0.41x-0.77x (no RHT), 0.50x-0.94x (+RHT) Also drops unused THREADS_X_TR / THREADS_Y_TR (nvcc warning NVIDIA#177-D). Co-authored-by: Zhongbo Zhu <zhongboz@nvidia.com> Signed-off-by: Cael Ling <caell@nvidia.com>
for more information, see https://pre-commit.ci
Add an optional fused-swizzle path to the NVFP4 per-token K2 encode
kernel: when with_swizzle=True the rowwise scale_inv is emitted directly
in the cuBLAS LT 128Mx4K swizzled tile layout, skipping the downstream
nvte_swizzle_scaling_factors launch. The colwise scale_inv stays in the
compact M-major layout (rowwise-only fusion for now).
The new code path is gated by a kWithSwizzle template parameter on
per_token_encode_kernel. The scatter epilogue uses thread mapping
b=tid&3, ty=tid>>2 to give each warp a coalesced 128-byte gmem store,
and packs two K-tiles into one uint64_t SMEM load (2-way bank conflict
instead of 4-way). Pre-existing code path is byte-equal.
with_swizzle is threaded through nvte_nvfp4_per_token_{quantize,encode},
their PyTorch bindings, and the nvfp4_per_token_{quantize,encode} Python
recipes. nvfp4_per_token_gemm takes new a_sf_swizzled / b_sf_swizzled
flags so the caller opts into the fast path per operand (mirrors prod
NVFP4 GEMM's per-operand swizzle).
Add tex.nvfp4_per_token_swizzle_rowwise_sf -- a thin wrapper around
nvte_swizzle_scaling_factors that does one standalone per-operand
swizzle launch. Bench-only; lets --qs attribute swizzle cost separately
from K1+K2 and from cuBLAS LT GEMM.
Bench (bench_nvfp4_per_token.py): add --qs mode (K1+K2 + standalone
swizzle, no GEMM) with two modifiers -- --pair (2 operands, matches one
prod GEMM call's quant+swizzle pipeline) and --fuse (adds a per-token
(fuse) column for the K2-fused path). The existing --swizzle end-to-end
mode also gains the fused-swizzle column. --pair / --fuse auto-imply
--qs to avoid silent fall-through to the default --composite table.
Tests (test_nvfp4_per_token.py): byte-equality of the fused-swizzle
rowwise SF vs a pure-Python permutation reference, byte-equality of all
other outputs (FP4 data, colwise SF, row/col amax) vs with_swizzle=False,
and numerical equivalence of the end-to-end GEMM via both code paths.
Perf at K=N=4096, Graph mode: fused-swizzle path is ~7-35% faster than
the unfused per-token pipeline (--qs) and reaches up to ~2.6x faster
than per-tensor at small M.
Co-authored-by: Zhongbo Zhu <zhongboz@nvidia.com>
Co-authored-by: Jiaxing Qi <jqi@nvidia.com>
Signed-off-by: Cael Ling <caell@nvidia.com>
for more information, see https://pre-commit.ci
The per-token cuBLASLt NVFP4 path needs a trailing post-scale kernel
(D *= alpha_a[i] * alpha_b[j]) that is HBM-bound on the M*N output. This
patch ships a forked-CUTLASS NVFP4 GEMM whose EVT epilogue folds the
per-row * per-col rescale into the in-TMEM accumulator -- a single launch
with no separate post-scale, no M*N HBM round-trip.
New C-API entry points (transformer_engine/common/gemm/nvfp4_cutlass_gemm.cu):
- nvte_nvfp4_cutlass_gemm: scalar (alpha, beta) NVFP4xNVFP4 -> BF16 GEMM
(CUTLASS analog of the cuBLASLt per-tensor path; used as test ground truth).
- nvte_nvfp4_cutlass_per_token_gemm: same mainloop, EVT epilogue
D[i,j] = bf16(NVFP4_DEQUANT_K * alpha_a[i] * alpha_b[j] * acc).
The outer 1/2688^2 factor (NVFP4 spec) is baked into the EVT explicitly,
matching the value cuBLASLt auto-folds via its amax slot.
Python bindings (tex.nvfp4_cutlass_gemm / tex.nvfp4_cutlass_per_token_gemm)
plus a/b_sf_swizzled flags for apples-to-apples --gemm-only benching.
Numerical correctness (tests/pytorch/nvfp4/test_nvfp4_cutlass_per_token_gemm.py):
- fused EVT == cuBLASLt per-token within bf16 ULP (rtol=2e-2), across
M,N,K = 256..1024.
- fused EVT with unity alphas == nvfp4_cutlass_gemm(alpha=1/2688^2) BIT-EXACT
(sanity check that the EVT tree and the baked constant are both correct).
Bench (tests/pytorch/nvfp4/bench_nvfp4_per_token.py --gemm-only) streamlined
to the only comparison that matters for shipping: ct_fused (per-token CUTLASS
fused) vs pten_gemm (prod per-tensor cuBLASLt), with the cf/pten ratio.
Co-authored-by: Zhongbo Zhu <zhongboz@nvidia.com>
Co-authored-by: Jiaxing Qi <jqi@nvidia.com>
Signed-off-by: Cael Ling <caell@nvidia.com>
for more information, see https://pre-commit.ci
Extends tests/pytorch/nvfp4/{bench,test}_nvfp4_cutlass_per_token_gemm
with end-to-end forward and backward coverage that aligns the prod
baseline with NVFP4BlockScaling real-ship defaults (input RHT-1D,
weight 2D no-RHT, grad RHT-cols + SR), so per-token (no RHT/SR) is
measured against an actually-shippable prod recipe rather than a
toy quantizer.
bench_nvfp4_per_token.py:
* --e2e-fwd: per-token quant (with_swizzle=True) + fused-EVT CUTLASS
GEMM vs NVFP4Quantizer + general_gemm (the real nn.Linear fwd
dispatch). Quant + GEMM inside the timing loop, N = K. Function
docstring carries an ASCII kernel-pipeline diagram for both paths
(per-call launch budget: per-token ~5 vs prod ~10).
* --e2e-bwd: real prod nn.Linear.bwd lifecycle. Timing loop = 1 x dY
quant + dgrad GEMM + wgrad GEMM; X and W are pre-quantized OUTSIDE
the loop (mirrors prod's reuse of fwd-saved QuantizedTensorStorage,
bwd never re-quantizes). pten side uses RHT-cols + SR grad
quantizer + general_gemm NN (dgrad) / NT (wgrad). Function docstring
carries an ASCII kernel-pipeline diagram (per-step launch budget:
per-token ~4 vs prod ~12).
* --gemm-only: 3-way table adds an lt_post column (cuBLASLt NVFP4 +
bf16 per-row*per-col post-scale, "Route 1") next to the existing
ct_fused fused-EVT path ("Route 2") and the prod pten_gemm
baseline. Headline ratio lp/cf decides whether to dispatch
per-token through cuBLASLt + post_scale or fused EVT; current
data shows ct_fused wins or ties at every shape we care about.
test_nvfp4_cutlass_per_token_gemm.py:
* Layer 2 fwd: per-token quant + fused-EVT GEMM vs BF16 fp32 ground
truth (rel_l2 < 0.30, robust to per-shape noise).
* Layer 3 fwd: dual-SNR table comparing per-token vs prod, both
measured against BF16 ground truth, with a per-token-vs-prod ratio.
* Layer 3 bwd: same dual-SNR pattern for dgrad and wgrad. Prod side
uses real-ship NVFP4BlockScaling grad quantizer (RHT cols + SR);
per-token side has no RHT/SR (numerical-floor comparison).
* Sanity micro-test for weight 2D quant plumbing through general_gemm
(catches breakage cheaper than the broader Layer 3 test).
Co-authored-by: Zhongbo Zhu <zhongboz@nvidia.com>
Co-authored-by: Jiaxing Qi <jqi@nvidia.com>
for more information, see https://pre-commit.ci
| DIVUP_TO_MULTIPLE(buff_elems_total_in * sizeof(IType), TMA_SHMEM_ALIGNMENT); | ||
| constexpr int dshmem_size = buff_size_aligned_in + TMA_SHMEM_ALIGNMENT; // + align pad | ||
|
|
||
| dim3 grid(static_cast<unsigned>(K / CHUNK_DIM_X), static_cast<unsigned>(M / CHUNK_DIM_Y), 1); |
There was a problem hiding this comment.
maybe use DIVUP here to handle the remainder case?
There was a problem hiding this comment.
This fast path has a hard precondition that M and K are exact multiples of CHUNK_DIM (128): validate() does NVTE_CHECK(M % CHUNK_DIM_Y == 0) / NVTE_CHECK(K % CHUNK_DIM_X == 0), and is_supported() returns false unless both hold — so any non-multiple shape is rejected / routed to the generic per-token fallback before it ever reaches this launcher.
| // After all 4 stages, emit one atomicMaxFloat per row slot + one per col slot. | ||
| // | ||
| // kWithRht=true: col-wise amax over RHT-rotated 16-row strips (per-thread | ||
| // FHT with random_sign_mask_t). Row direction never sees RHT. |
There was a problem hiding this comment.
typo: Row direction never sees RHT -> Row direction never uses RHT
| } | ||
| } | ||
| #else | ||
| NVTE_DEVICE_ERROR("Per-token amax kernel requires SM 10.0+ (Blackwell)."); |
There was a problem hiding this comment.
For these quantization kernel, TMA only require SM 9.0+ only. Is there any other constraints that limit to sm 10.0+?
There was a problem hiding this comment.
The CUDA_ARCH >= 1000 guard is intentional but not because of a hardware op in this kernel. Two reasons:
- The shared TE PTX wrappers it calls — cp_async_bulk_tensor_2d_global_to_shared and mbarrier_wait_parity_acquire_cta_shared_cta in util/ptx.cuh — are themselves guarded to >= 1000 and emit NVTE_DEVICE_ERROR below that. They were authored/validated only for the Blackwell path.
- The whole NVFP4 quantize path is host-gated to SM100 anyway (NVTE_ERROR("NVFP4 requires SM100 ...")), since NVFP4 is a Blackwell datatype and the downstream FP4 GEMM that consumes these scales only exists on SM100. So the amax kernel is never launched off <SM100; the per-arch guard just yields a clean error instead of an undefined symbol.
Add NN/NT GEMM layout dispatch so the per-token NVFP4 path covers dgrad and wgrad, and let per-token opt into RHT via NVFP4PerTokenBlockScaling(per_token_rht=...) while SR/2D stay disabled (kernels unimplemented at this commit). Extends the per-token CUTLASS GEMM, the torch NVFP4Quantizer, and the NVFP4Tensor plumbing, plus dgrad/wgrad numerical tests and a fwd+bwd module smoke test. Co-authored-by: Zhongbo Zhu <zhongboz@nvidia.com> Co-authored-by: Jiaxing Qi <jqi@nvidia.com> Signed-off-by: Cael Ling <caell@nvidia.com>
for more information, see https://pre-commit.ci
Thread a Philox rng_state and a kWithSr template flag through the per-token encode kernel (rowwise + colwise) and the nvte_nvfp4_per_token_encode/quantize C-API, mirroring the per-tensor SR path. Drop the SR mutex check in the torch NVFP4Quantizer and build the rng_state when stochastic rounding is requested. Add a per_token_sr recipe flag on NVFP4PerTokenBlockScaling wired through the quantizer factory, plus statistical tests (SR unbiasedness -- lower RMSE than RN when averaged -- and RN-determinism / SR-nondeterminism) folded into test_nvfp4_per_token.py. Co-authored-by: Zhongbo Zhu <zhongboz@nvidia.com> Co-authored-by: Jiaxing Qi <jqi@nvidia.com> Signed-off-by: Cael Ling <caell@nvidia.com>
for more information, see https://pre-commit.ci
Wire with_sr + rng_state through the grouped per-token C-API and cast dispatch, implement the SR FP4 cast in the grouped kernel, and drop the "per-token does not support SR" guard. Also fix two comment typos (sees -> uses) in quantize_nvfp4_per_token.cu per review. Co-authored-by: Zhongbo Zhu <zhongboz@nvidia.com> Co-authored-by: Jiaxing Qi <jqi@nvidia.com> Signed-off-by: Cael Ling <caell@nvidia.com>
for more information, see https://pre-commit.ci
Description
This PR adds an NVFP4 per-token quantization fast path for bf16 inputs on Blackwell (SM100+) for Model Pre-training. Per-token uses per-row / per-column outer amax instead of the per-tensor scalar amax, which factors cleanly out of the GEMM K-summation and lets the inner GEMM stay on production cuBLASLT NVFP4 plus a thin trailing post-scale.
Status: draft. The cast kernel and GEMM composite, byte-equal-verified against a Python reference, and benched against the per-tensor (RHT + SR) recipe are still in progress. Partial experimental results are shown as follows.
Tests and benches
This PR ships four new pytest / benchmark files under
tests/pytorch/nvfp4/. All four require bf16 input andM % 128 == 0/K % 128 == 0; GEMM tests are gated by SM100 (Blackwell).test_nvfp4_per_token.py— single-tensor correctnesstest_nvfp4_per_token_group.py— group-tensor correctnessbench_nvfp4_per_token.py— single-tensor amax + quant benchWall-time benchmark of the single-tensor amax + quant composite (
tex.nvfp4_per_token_quantize, this PR) against the per-tensor RHT+SR production baseline (NVFP4Quantizer(rht=True, sr=True)viatex.quantize). Both sides use rowwise + columnwise. Single output table:Each row reports eager wall-time plus CUDA Graphs replay (kernel-only floor).
ratio < 1.0⇒ per-token is faster than the per-tensor baseline.bench_nvfp4_per_token_group.py— grouped amax + quant benchWall-time benchmark of the grouped amax + quant composite (
nvfp4_per_token_group_quantizePython wrapper backed by the new C++ bulk entry, this PR) against the per-tensor RHT+SR grouped production baseline (tex.split_quantize(...)withNVFP4Quantizer(rht=True, sr=True)per split). Layout identical to the single-tensor bench:Default sweep is 6 × 3 = 18 cases at fixed
N = 8equal splits (MoE-typical):sum_M ∈ {1024, 2048, 4096, 8192, 16384, 32768}(so per-splitM_i ∈ {128 … 4096}) ×K ∈ {2048, 4096, 8192}. CUDA Graphs replay reported on every row.Type of change
Changes
Please list the changes introduced in this PR:
Checklist: